package com.chatgpt.api.websocket.handler;

import com.chatgpt.api.websocket.ChatMessage;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.*;

import java.io.IOException;

@Component
@Slf4j
public class ChatHandler implements WebSocketHandler {
    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {

    }

    @Override
    public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
        log.info("参数列表: {}", session.getAttributes());
        log.info("收到客户端消息:{}", message.getPayload());
        // TODO: 这里处理消息逻辑，将消息推送到大语言模型的 API，然后把输出流重定向到 websocket 响应
        sendMessage(session, "这是来自 Java API 端的消息！");
    }

    @Override
    public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {

    }

    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {

    }

    @Override
    public boolean supportsPartialMessages() {
        return false;
    }

    public void sendMessage(WebSocketSession session, String message) throws IOException {
        ChatMessage start = new ChatMessage("start", "");
        ChatMessage chunk = new ChatMessage("middle", message);
        ChatMessage end = new ChatMessage("end", "");
        session.sendMessage(new BinaryMessage(start.toByte()));
        session.sendMessage(new BinaryMessage(chunk.toByte()));
        session.sendMessage(new BinaryMessage(end.toByte()));
    }

    public void sendChunkMessage(WebSocketSession session, String message) throws IOException {
        ChatMessage chunk = new ChatMessage("middle", message);
        session.sendMessage(new BinaryMessage(chunk.toByte()));
    }
}
